{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "If you were not here for Lab 12 and need to install the graphviz package, run:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install --user graphviz" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Lab 13 - Decision Trees for regression\n", "\n", "For this lab, we will return to the insurance data from Labs 7 and 8. Recall that we are trying to predict the insurance cost, a quantitative value. \n", "\n", "If you don't have the dataset, download it from GitHub: [https://github.com/stedy/Machine-Learning-with-R-datasets/blob/master/insurance.csv](https://github.com/stedy/Machine-Learning-with-R-datasets/blob/master/insurance.csv)\n", "\n", "In this data, each row represents an insurance policy and the 7 columns contain the following information about it:\n", "- age: age of policy holder\n", "- sex: sex of policy holder\n", "- bmi: body mass index (bmi) of policy holder. 
bmi is a (sometimes unreliable) measurement of body fat in adults\n", "- children: number of children (dependents) on the policy\n", "- smoker: whether the policy holder is a smoker\n", "- region: region of the country the policy holder lives in\n", "- charges: price for insurance policy" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "from sklearn import tree\n", "import graphviz\n", "from graphviz import Source\n", " \n", "from sklearn.model_selection import train_test_split\n", "\n", "from sklearn.tree import export_graphviz\n", "import sklearn.metrics as met\n", "from sklearn.metrics import confusion_matrix\n", "\n", "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Read the data into a dataframe and display it to make sure it was read in correctly:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<table border=\"1\" class=\"dataframe\">\n",
"<thead>\n",
"<tr><th></th><th>age</th><th>sex</th><th>bmi</th><th>children</th><th>smoker</th><th>region</th><th>charges</th></tr>\n",
"</thead>\n",
"<tbody>\n",
"<tr><th>0</th><td>19</td><td>female</td><td>27.900</td><td>0</td><td>yes</td><td>southwest</td><td>16884.92400</td></tr>\n",
"<tr><th>1</th><td>18</td><td>male</td><td>33.770</td><td>1</td><td>no</td><td>southeast</td><td>1725.55230</td></tr>\n",
"<tr><th>2</th><td>28</td><td>male</td><td>33.000</td><td>3</td><td>no</td><td>southeast</td><td>4449.46200</td></tr>\n",
"<tr><th>3</th><td>33</td><td>male</td><td>22.705</td><td>0</td><td>no</td><td>northwest</td><td>21984.47061</td></tr>\n",
"<tr><th>4</th><td>32</td><td>male</td><td>28.880</td><td>0</td><td>no</td><td>northwest</td><td>3866.85520</td></tr>\n",
"</tbody>\n",
"</table>
" ], "text/plain": [ " age sex bmi children smoker region charges\n", "0 19 female 27.900 0 yes southwest 16884.92400\n", "1 18 male 33.770 1 no southeast 1725.55230\n", "2 28 male 33.000 3 no southeast 4449.46200\n", "3 33 male 22.705 0 no northwest 21984.47061\n", "4 32 male 28.880 0 no northwest 3866.85520" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "insurance = pd.read_csv(\"insurance.csv\")\n", "insurance.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Scikit-learn decision trees require numeric data. How can we convert the categorical columns into numeric data? \n", "Hint: see Lab 8" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "insurance = pd.get_dummies(insurance, columns = [\"sex\", \"smoker\", \"region\"], drop_first = True)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<table border=\"1\" class=\"dataframe\">\n",
"<thead>\n",
"<tr><th></th><th>age</th><th>bmi</th><th>children</th><th>charges</th><th>sex_male</th><th>smoker_yes</th><th>region_northwest</th><th>region_southeast</th><th>region_southwest</th></tr>\n",
"</thead>\n",
"<tbody>\n",
"<tr><th>0</th><td>19</td><td>27.900</td><td>0</td><td>16884.92400</td><td>0</td><td>1</td><td>0</td><td>0</td><td>1</td></tr>\n",
"<tr><th>1</th><td>18</td><td>33.770</td><td>1</td><td>1725.55230</td><td>1</td><td>0</td><td>0</td><td>1</td><td>0</td></tr>\n",
"<tr><th>2</th><td>28</td><td>33.000</td><td>3</td><td>4449.46200</td><td>1</td><td>0</td><td>0</td><td>1</td><td>0</td></tr>\n",
"<tr><th>3</th><td>33</td><td>22.705</td><td>0</td><td>21984.47061</td><td>1</td><td>0</td><td>1</td><td>0</td><td>0</td></tr>\n",
"<tr><th>4</th><td>32</td><td>28.880</td><td>0</td><td>3866.85520</td><td>1</td><td>0</td><td>1</td><td>0</td><td>0</td></tr>\n",
"</tbody>\n",
"</table>
" ], "text/plain": [ " age bmi children charges sex_male smoker_yes region_northwest \\\n", "0 19 27.900 0 16884.92400 0 1 0 \n", "1 18 33.770 1 1725.55230 1 0 0 \n", "2 28 33.000 3 4449.46200 1 0 0 \n", "3 33 22.705 0 21984.47061 1 0 1 \n", "4 32 28.880 0 3866.85520 1 0 1 \n", "\n", " region_southeast region_southwest \n", "0 0 1 \n", "1 1 0 \n", "2 1 0 \n", "3 0 0 \n", "4 0 0 " ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "insurance.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Fitting a decision tree with scikit-learn\n", "\n", "We can get just the independent variables (x's) using the following:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<table border=\"1\" class=\"dataframe\">\n",
"<thead>\n",
"<tr><th></th><th>age</th><th>bmi</th><th>children</th><th>sex_male</th><th>smoker_yes</th><th>region_northwest</th><th>region_southeast</th><th>region_southwest</th></tr>\n",
"</thead>\n",
"<tbody>\n",
"<tr><th>0</th><td>19</td><td>27.900</td><td>0</td><td>0</td><td>1</td><td>0</td><td>0</td><td>1</td></tr>\n",
"<tr><th>1</th><td>18</td><td>33.770</td><td>1</td><td>1</td><td>0</td><td>0</td><td>1</td><td>0</td></tr>\n",
"<tr><th>2</th><td>28</td><td>33.000</td><td>3</td><td>1</td><td>0</td><td>0</td><td>1</td><td>0</td></tr>\n",
"<tr><th>3</th><td>33</td><td>22.705</td><td>0</td><td>1</td><td>0</td><td>1</td><td>0</td><td>0</td></tr>\n",
"<tr><th>4</th><td>32</td><td>28.880</td><td>0</td><td>1</td><td>0</td><td>1</td><td>0</td><td>0</td></tr>\n",
"</tbody>\n",
"</table>
" ], "text/plain": [ " age bmi children sex_male smoker_yes region_northwest \\\n", "0 19 27.900 0 0 1 0 \n", "1 18 33.770 1 1 0 0 \n", "2 28 33.000 3 1 0 0 \n", "3 33 22.705 0 1 0 1 \n", "4 32 28.880 0 1 0 1 \n", "\n", " region_southeast region_southwest \n", "0 0 1 \n", "1 1 0 \n", "2 1 0 \n", "3 0 0 \n", "4 0 0 " ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X = insurance.iloc[:,[0,1,2,4,5,6,7,8]]\n", "X.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next we create the decision tree object and then fit it to our data:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "reg = tree.DecisionTreeRegressor(max_depth = 5)\n", "reg = reg.fit(X, insurance[\"charges\"])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you are running Jupyter Hub on your own computer, you may be able to display the decision tree with:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tree.plot_tree(reg)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you are using the Jupyter Hub server, run the following code (which will give an error):" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "scrolled": true }, "outputs": [ { "ename": "PermissionError", "evalue": "[Errno 13] Permission denied", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mPermissionError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mdot_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtree\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexport_graphviz\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreg\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mout_file\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mgraph\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgraphviz\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSource\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdot_data\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mgraph\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrender\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"insurance.dot\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m~/.local/lib/python3.4/site-packages/graphviz/files.py\u001b[0m in \u001b[0;36mrender\u001b[0;34m(self, filename, directory, view, cleanup, format, renderer, formatter)\u001b[0m\n\u001b[1;32m 186\u001b[0m \u001b[0mformat\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_format\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 187\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 188\u001b[0;31m \u001b[0mrendered\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbackend\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrender\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_engine\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mformat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfilepath\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrenderer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mformatter\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 189\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 190\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mcleanup\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/.local/lib/python3.4/site-packages/graphviz/backend.py\u001b[0m in \u001b[0;36mrender\u001b[0;34m(engine, format, filepath, renderer, formatter, quiet)\u001b[0m\n\u001b[1;32m 181\u001b[0m \"\"\"\n\u001b[1;32m 182\u001b[0m \u001b[0mcmd\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mrendered\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcommand\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mengine\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mformat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfilepath\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrenderer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mformatter\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 183\u001b[0;31m \u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcmd\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcapture_output\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcheck\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mquiet\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mquiet\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 184\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mrendered\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 185\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/.local/lib/python3.4/site-packages/graphviz/backend.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(cmd, input, capture_output, check, quiet, **kwargs)\u001b[0m\n\u001b[1;32m 145\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 146\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 147\u001b[0;31m \u001b[0mproc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msubprocess\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mPopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcmd\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstartupinfo\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mget_startupinfo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 148\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mOSError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 
149\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0merrno\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0merrno\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mENOENT\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.4/subprocess.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, args, bufsize, executable, stdin, stdout, stderr, preexec_fn, close_fds, shell, cwd, env, universal_newlines, startupinfo, creationflags, restore_signals, start_new_session, pass_fds)\u001b[0m\n\u001b[1;32m 854\u001b[0m \u001b[0mc2pread\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mc2pwrite\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 855\u001b[0m \u001b[0merrread\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0merrwrite\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 856\u001b[0;31m restore_signals, start_new_session)\n\u001b[0m\u001b[1;32m 857\u001b[0m \u001b[0;32mexcept\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 858\u001b[0m \u001b[0;31m# Cleanup if the child failed starting.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.4/subprocess.py\u001b[0m in \u001b[0;36m_execute_child\u001b[0;34m(self, args, executable, preexec_fn, close_fds, pass_fds, cwd, env, startupinfo, creationflags, shell, p2cread, p2cwrite, c2pread, c2pwrite, errread, errwrite, restore_signals, start_new_session)\u001b[0m\n\u001b[1;32m 1462\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1463\u001b[0m \u001b[0merr_msg\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m': '\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mrepr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0morig_executable\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1464\u001b[0;31m \u001b[0;32mraise\u001b[0m 
\u001b[0mchild_exception_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0merrno_num\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0merr_msg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1465\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mchild_exception_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0merr_msg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1466\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mPermissionError\u001b[0m: [Errno 13] Permission denied" ] } ], "source": [ "dot_data = tree.export_graphviz(reg, out_file=None) \n", "graph = graphviz.Source(dot_data) \n", "graph.render(\"insurance.dot\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "However, despite the error, there should now be a file called insurance.dot in your directory. To view the fitted decision tree, open the insurance.dot file in Jupyter and copy the text. Paste this text into the text box at [http://www.webgraphviz.com](http://www.webgraphviz.com) and click the \"Generate graph!\" button at the bottom.\n", "\n", "In the graph, the column names appear as `X[0], X[1], ..., X[7]`. Run the following code to replace `X[0], X[1], ..., X[7]` with the column names and save the result as insurance_fixed.dot." 
] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "# Replace the generic feature labels X[0]..X[7] with the column names of X\n", "with open(\"insurance.dot\", \"r\") as fin:\n", "    with open(\"insurance_fixed.dot\", \"w\") as fout:\n", "        for line in fin:\n", "            line = line.replace(\"X[0]\", \"age\")\n", "            line = line.replace(\"X[1]\", \"bmi\")\n", "            line = line.replace(\"X[2]\", \"children\")\n", "            line = line.replace(\"X[3]\", \"sex_male\")\n", "            line = line.replace(\"X[4]\", \"smoker_yes\")\n", "            line = line.replace(\"X[5]\", \"region_northwest\")\n", "            line = line.replace(\"X[6]\", \"region_southeast\")\n", "            line = line.replace(\"X[7]\", \"region_southwest\")\n", "            fout.write(line)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Copy the contents of insurance_fixed.dot into the textbox in [http://www.webgraphviz.com](http://www.webgraphviz.com) to display the decision tree with the column names. How does it compare to the decision tree you made?\n", "\n", "What happens if you change the `max_depth` parameter to 5 in DecisionTreeRegressor?\n", "\n", "Look at the leaves of your new tree. What is the smallest number of samples in a leaf? \n", "\n", "A few of the leaves only have 1 sample. How do you think this tree would work on other insurance data?\n", "\n", "Leaves with a single sample are a sign of overfitting. To fix it we can make `max_depth` smaller (but if it is too small, our model will not be as good as it could be).\n", "\n", "### Testing and training data\n", "\n", "To figure out what `max_depth` should be, let's split our data into training and testing data. 
" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "X_train, X_test, y_train, y_test = train_test_split(X, insurance[\"charges\"], test_size=0.2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Create a decision tree with `max_depth = 3` from the training data:" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "reg3 = tree.DecisionTreeRegressor(max_depth = 3)\n", "reg3 = reg3.fit(X_train, y_train)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Make predictions for the test data:" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "array([13785.78482646, 6161.11114528, 18723.54895898, 6161.11114528,\n", " 6161.11114528, 13785.78482646, 2795.03725199, 2795.03725199,\n", " 13785.78482646, 6161.11114528, 2795.03725199, 2795.03725199,\n", " 6161.11114528, 13785.78482646, 6161.11114528, 18723.54895898,\n", " 10351.84377925, 13785.78482646, 13785.78482646, 38722.41603063,\n", " 38722.41603063, 13785.78482646, 2795.03725199, 6161.11114528,\n", " 6161.11114528, 6161.11114528, 13785.78482646, 13785.78482646,\n", " 10351.84377925, 6161.11114528, 18723.54895898, 45606.72260404,\n", " 18723.54895898, 2795.03725199, 10351.84377925, 6161.11114528,\n", " 6161.11114528, 13785.78482646, 45606.72260404, 6161.11114528,\n", " 10351.84377925, 45606.72260404, 18723.54895898, 6161.11114528,\n", " 6161.11114528, 6161.11114528, 2795.03725199, 10351.84377925,\n", " 18723.54895898, 38722.41603063, 45606.72260404, 13785.78482646,\n", " 13785.78482646, 13785.78482646, 6161.11114528, 2795.03725199,\n", " 2795.03725199, 6161.11114528, 24603.2390669 , 13785.78482646,\n", " 13785.78482646, 13785.78482646, 6161.11114528, 24603.2390669 ,\n", " 2795.03725199, 2795.03725199, 13785.78482646, 6161.11114528,\n", " 13785.78482646, 6161.11114528, 10351.84377925, 6161.11114528,\n", " 10351.84377925, 6161.11114528, 
6161.11114528, 2795.03725199,\n", " 2795.03725199, 6161.11114528, 18723.54895898, 13785.78482646,\n", " 6161.11114528, 6161.11114528, 24603.2390669 , 6161.11114528,\n", " 24603.2390669 , 24603.2390669 , 6161.11114528, 13785.78482646,\n", " 13785.78482646, 13785.78482646, 18723.54895898, 6161.11114528,\n", " 6161.11114528, 13785.78482646, 38722.41603063, 13785.78482646,\n", " 13785.78482646, 10351.84377925, 24603.2390669 , 6161.11114528,\n", " 6161.11114528, 13785.78482646, 6161.11114528, 18723.54895898,\n", " 10351.84377925, 13785.78482646, 18723.54895898, 2795.03725199,\n", " 13785.78482646, 6161.11114528, 6161.11114528, 6161.11114528,\n", " 24603.2390669 , 6161.11114528, 13785.78482646, 13785.78482646,\n", " 6161.11114528, 6161.11114528, 45606.72260404, 13785.78482646,\n", " 6161.11114528, 38722.41603063, 6161.11114528, 6161.11114528,\n", " 6161.11114528, 2795.03725199, 6161.11114528, 6161.11114528,\n", " 2795.03725199, 13785.78482646, 6161.11114528, 6161.11114528,\n", " 45606.72260404, 13785.78482646, 38722.41603063, 13785.78482646,\n", " 10351.84377925, 13785.78482646, 13785.78482646, 6161.11114528,\n", " 13785.78482646, 2795.03725199, 2795.03725199, 6161.11114528,\n", " 10351.84377925, 18723.54895898, 10351.84377925, 2795.03725199,\n", " 10351.84377925, 18723.54895898, 10351.84377925, 2795.03725199,\n", " 2795.03725199, 6161.11114528, 6161.11114528, 10351.84377925,\n", " 10351.84377925, 6161.11114528, 6161.11114528, 10351.84377925,\n", " 10351.84377925, 18723.54895898, 13785.78482646, 38722.41603063,\n", " 6161.11114528, 24603.2390669 , 10351.84377925, 13785.78482646,\n", " 6161.11114528, 6161.11114528, 6161.11114528, 10351.84377925,\n", " 6161.11114528, 13785.78482646, 6161.11114528, 13785.78482646,\n", " 13785.78482646, 45606.72260404, 38722.41603063, 6161.11114528,\n", " 2795.03725199, 24603.2390669 , 10351.84377925, 38722.41603063,\n", " 10351.84377925, 38722.41603063, 6161.11114528, 10351.84377925,\n", " 6161.11114528, 10351.84377925, 6161.11114528, 
13785.78482646,\n", " 38722.41603063, 38722.41603063, 2795.03725199, 13785.78482646,\n", " 10351.84377925, 13785.78482646, 10351.84377925, 13785.78482646,\n", " 6161.11114528, 2795.03725199, 10351.84377925, 13785.78482646,\n", " 10351.84377925, 2795.03725199, 10351.84377925, 6161.11114528,\n", " 2795.03725199, 45606.72260404, 13785.78482646, 6161.11114528,\n", " 13785.78482646, 13785.78482646, 10351.84377925, 6161.11114528,\n", " 24603.2390669 , 2795.03725199, 13785.78482646, 13785.78482646,\n", " 13785.78482646, 6161.11114528, 6161.11114528, 6161.11114528,\n", " 13785.78482646, 6161.11114528, 10351.84377925, 6161.11114528,\n", " 2795.03725199, 13785.78482646, 2795.03725199, 10351.84377925,\n", " 38722.41603063, 13785.78482646, 13785.78482646, 45606.72260404,\n", " 13785.78482646, 24603.2390669 , 6161.11114528, 6161.11114528,\n", " 6161.11114528, 6161.11114528, 24603.2390669 , 10351.84377925,\n", " 6161.11114528, 6161.11114528, 6161.11114528, 10351.84377925,\n", " 24603.2390669 , 38722.41603063, 10351.84377925, 2795.03725199,\n", " 6161.11114528, 2795.03725199, 45606.72260404, 2795.03725199,\n", " 24603.2390669 , 2795.03725199, 2795.03725199, 6161.11114528,\n", " 10351.84377925, 13785.78482646, 10351.84377925, 10351.84377925,\n", " 13785.78482646, 18723.54895898, 10351.84377925, 2795.03725199])" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictions_3 = reg3.predict(X_test)\n", "predictions_3" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Compute the mean squared error for these predictions:" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "27792662.098281976" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "((y_test - predictions_3)**2).mean()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "What is the mean squared error if you use `max_depth = 4`?" 
] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "29737038.361192185" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "reg4 = tree.DecisionTreeRegressor(max_depth = 4)\n", "reg4 = reg4.fit(X_train, y_train)\n", "predictions_4 = reg4.predict(X_test)\n", "((y_test - predictions_4)**2).mean()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "What is the mean squared error if you use `max_depth = 5`?" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "32480100.453982316" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "reg5 = tree.DecisionTreeRegressor(max_depth = 5)\n", "reg5 = reg5.fit(X_train, y_train)\n", "predictions_5 = reg5.predict(X_test)\n", "((y_test - predictions_5)**2).mean()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "What about if you use `max_depth = 2`?" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "32211767.651123475" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "reg2 = tree.DecisionTreeRegressor(max_depth = 2)\n", "reg2 = reg2.fit(X_train, y_train)\n", "predictions_2 = reg2.predict(X_test)\n", "((y_test - predictions_2)**2).mean()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Which `max_depth` parameter should you use? What is the corresponding decision tree?\n", "\n", "You can also use a loop to quickly check the different parameter values for `max_depth`. 
" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# reg3 is the max_depth = 3 tree fitted to the training data above\n", "dot_data = tree.export_graphviz(reg3, out_file=None) \n", "graph = graphviz.Source(dot_data) \n", "graph.render(\"insurance_depth3.dot\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Replace the generic feature labels X[0]..X[7] with the column names of X\n", "with open(\"insurance_depth3.dot\", \"r\") as fin:\n", "    with open(\"insurance_depth3_fixed.dot\", \"w\") as fout:\n", "        for line in fin:\n", "            line = line.replace(\"X[0]\", \"age\")\n", "            line = line.replace(\"X[1]\", \"bmi\")\n", "            line = line.replace(\"X[2]\", \"children\")\n", "            line = line.replace(\"X[3]\", \"sex_male\")\n", "            line = line.replace(\"X[4]\", \"smoker_yes\")\n", "            line = line.replace(\"X[5]\", \"region_northwest\")\n", "            line = line.replace(\"X[6]\", \"region_southeast\")\n", "            line = line.replace(\"X[7]\", \"region_southwest\")\n", "            fout.write(line)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we can compare the mean squared error of the decision tree regressor with the mean squared error of the linear regression from Lab 8, which was also computed with a test size of 0.2. The linear regression's mean squared error was 41142821.67547247 (for my training/testing data).\n", "\n", "Which model is better?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Return to the decision tree classifier from last lab. Which `max_depth` is best?" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.4.8" } }, "nbformat": 4, "nbformat_minor": 2 }